{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c08309b8-13f0-45bb-a3ea-7b01f05a7346",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "import random\n",
    "import re\n",
    "import subprocess\n",
    "import pyarrow as pa\n",
    "from typing import List\n",
    "import openai\n",
    "import anthropic\n",
    "from dotenv import load_dotenv\n",
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5efd903-e683-4e7f-8747-2998e23a0751",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load API\n",
    "load_dotenv(override=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce49b86a-53f4-4d4f-a721-0d66d9c1b070",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Schema Definition ---\n",
    "SCHEMA = [\n",
    "    (\"Team\", \"TEXT\", '\"Toronto Raptors\"'),\n",
    "    (\"NAME\", \"TEXT\", '\"Otto Porter Jr.\"'),\n",
    "    (\"Jersey\", \"TEXT\", '\"10\", or \"NA\" if null'),\n",
    "    (\"POS\", \"TEXT\", 'One of [\"PF\",\"SF\",\"G\",\"C\",\"SG\",\"F\",\"PG\"]'),\n",
    "    (\"AGE\", \"INT\", 'integer age in years, e.g., 22'),\n",
    "    (\"HT\", \"TEXT\", '`6\\' 7\"` or `6\\' 10\"`'),\n",
    "    (\"WT\", \"TEXT\", '\"232 lbs\"'),\n",
    "    (\"COLLEGE\", \"TEXT\", '\"Michigan\", or \"--\" if null'),\n",
    "    (\"SALARY\", \"TEXT\", '\"$9,945,830\", or \"--\" if null')\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93743e57-c2c5-43e5-8fa1-2e242085db07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default schema text for the textbox\n",
    "DEFAULT_SCHEMA_TEXT = \"\\n\".join([f\"{i+1}. {col[0]} ({col[1]}) Example: {col[2]}\" for i, col in enumerate(SCHEMA)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87c58595-6fdd-48f5-a253-ccba352cb385",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Available models\n",
    "MODELS = [\n",
    "    \"gpt-4o\",\n",
    "    \"claude-3-5-haiku-20241022\", \n",
    "    \"ollama:llama3.2:latest\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08cd9ce2-8685-46b5-95d0-811b8025696f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Available file formats\n",
    "FILE_FORMATS = [\".csv\", \".tsv\", \".jsonl\", \".parquet\", \".arrow\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d68c7f-6f49-4efa-b075-f1e7db2ab527",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prompt(n: int, schema_text: str, system_prompt: str) -> str:\n",
    "    prompt = f\"\"\"\n",
    "{system_prompt}\n",
    "\n",
    "Generate {n} rows of realistic basketball player data in JSONL format, each line a JSON object with the following fields:\n",
    "\n",
    "{schema_text}\n",
    "\n",
    "Do NOT repeat column values from one row to another.\n",
    "\n",
    "Only output valid JSONL.\n",
    "\"\"\"\n",
    "    return prompt.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdc68f1e-4fbe-45dc-aa36-ce5f718ef6ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- LLM Interface ---\n",
    "def query_model(prompt: str, model: str = \"gpt-4o\") -> List[dict]:\n",
    "    \"\"\"Call OpenAI, Claude, or Ollama\"\"\"\n",
    "    try:\n",
    "        if model.lower().startswith(\"gpt\"):\n",
    "            client = openai.OpenAI(api_key=os.getenv(\"OPENAI_API_KEY\"))\n",
    "            response = client.chat.completions.create(\n",
    "                model=model,\n",
    "                messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "                temperature=0.7\n",
    "            )\n",
    "            content = response.choices[0].message.content\n",
    "\n",
    "        elif model.lower().startswith(\"claude\"):\n",
    "            client = anthropic.Anthropic(api_key=os.getenv(\"ANTHROPIC_API_KEY\"))\n",
    "            response = client.messages.create(\n",
    "                model=model,\n",
    "                messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "                max_tokens=4000,\n",
    "                temperature=0.7\n",
    "            )\n",
    "            content = response.content[0].text\n",
    "\n",
    "        elif model.lower().startswith(\"ollama:\"):\n",
    "            ollama_model = model.split(\":\")[1]\n",
    "            result = subprocess.run(\n",
    "                [\"ollama\", \"run\", ollama_model],\n",
    "                input=prompt,\n",
    "                text=True,\n",
    "                capture_output=True\n",
    "            )\n",
    "            if result.returncode != 0:\n",
    "                raise Exception(f\"Ollama error: {result.stderr}\")\n",
    "            content = result.stdout\n",
    "        else:\n",
    "            raise ValueError(\"Unsupported model. Use 'gpt-4.1-mini', 'claude-3-5-haiku-20241022', or 'ollama:llama3.2:latest'\")\n",
    "\n",
    "        # Parse JSONL output\n",
    "        lines = [line.strip() for line in content.strip().splitlines() if line.strip().startswith(\"{\")]\n",
    "        return [json.loads(line) for line in lines]\n",
    "    \n",
    "    except Exception as e:\n",
    "        raise Exception(f\"Model query failed: {str(e)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29e3f5f5-e99c-429c-bea9-69d554c58c9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Output Formatter ---\n",
    "def save_dataset(records: List[dict], file_format: str, filename: str):\n",
    "    df = pd.DataFrame(records)\n",
    "    if file_format == \".csv\":\n",
    "        df.to_csv(filename, index=False)\n",
    "    elif file_format == \".tsv\":\n",
    "        df.to_csv(filename, sep=\"\\t\", index=False)\n",
    "    elif file_format == \".jsonl\":\n",
    "        with open(filename, \"w\") as f:\n",
    "            for record in records:\n",
    "                f.write(json.dumps(record) + \"\\n\")\n",
    "    elif file_format == \".parquet\":\n",
    "        df.to_parquet(filename, engine=\"pyarrow\", index=False)\n",
    "    elif file_format == \".arrow\":\n",
    "        table = pa.Table.from_pandas(df)\n",
    "        with pa.OSFile(filename, \"wb\") as sink:\n",
    "            with pa.ipc.new_file(sink, table.schema) as writer:\n",
    "                writer.write(table)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported file format\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe258e84-66f4-4fe7-99c0-75b24148e147",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Main Generation Function ---\n",
    "def generate_dataset(schema_text, system_prompt, model, nr_records, file_format, save_as):\n",
    "    try:\n",
    "        # Validation\n",
    "        if nr_records <= 10:\n",
    "            return \"❌ Error: Nr_records must be greater than 10.\", None\n",
    "        \n",
    "        if file_format not in FILE_FORMATS:\n",
    "            return \"❌ Error: Invalid file format specified.\", None\n",
    "        \n",
    "        if not save_as or save_as.strip() == \"\":\n",
    "            save_as = f\"basketball_dataset{file_format}\"\n",
    "        elif not save_as.endswith(file_format):\n",
    "            save_as = save_as + file_format\n",
    "        \n",
    "        # Generate prompt\n",
    "        prompt = get_prompt(nr_records, schema_text, system_prompt)\n",
    "        \n",
    "        # Query model\n",
    "        records = query_model(prompt, model=model)\n",
    "        \n",
    "        if not records:\n",
    "            return \"❌ Error: No valid records generated from the model.\", None\n",
    "        \n",
    "        # Save dataset\n",
    "        save_dataset(records, file_format, save_as)\n",
    "        \n",
    "        # Create preview\n",
    "        df = pd.DataFrame(records)\n",
    "        preview = df.head(10)  # Show first 10 rows\n",
    "        \n",
    "        success_message = f\"✅ Dataset generated successfully!\\n📁 Saved to: {save_as}\\n📊 Generated {len(records)} records\"\n",
    "        \n",
    "        return success_message, preview\n",
    "    \n",
    "    except Exception as e:\n",
    "        return f\"❌ Error: {str(e)}\", None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2405a9d-b4cd-43d9-82f6-ff3512b4541f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Gradio Interface ---\n",
    "def create_interface():\n",
    "    with gr.Blocks(title=\"Dataset Generator\", theme=gr.themes.Soft()) as interface:\n",
    "        gr.Markdown(\"# Dataset Generator\")\n",
    "        gr.Markdown(\"Generate realistic datasets using AI models\")\n",
    "        \n",
    "        with gr.Row():\n",
    "            with gr.Column(scale=2):\n",
    "                schema_input = gr.Textbox(\n",
    "                    label=\"Schema\",\n",
    "                    value=DEFAULT_SCHEMA_TEXT,\n",
    "                    lines=15,\n",
    "                    placeholder=\"Define your dataset schema here...\"\n",
    "                )\n",
    "                \n",
    "                system_prompt_input = gr.Textbox(\n",
    "                    label=\"Prompt\",\n",
    "                    value=\"You are a helpful assistant that generates realistic basketball player data.\",\n",
    "                    lines=1,\n",
    "                    placeholder=\"Enter system prompt for the model...\"\n",
    "                )\n",
    "                \n",
    "                with gr.Row():\n",
    "                    model_dropdown = gr.Dropdown(\n",
    "                        label=\"Model\",\n",
    "                        choices=MODELS,\n",
    "                        value=MODELS[1],  # Default to Claude\n",
    "                        interactive=True\n",
    "                    )\n",
    "                    \n",
    "                    nr_records_input = gr.Number(\n",
    "                        label=\"Nr. records\",\n",
    "                        value=25,\n",
    "                        minimum=11,\n",
    "                        maximum=1000,\n",
    "                        step=1\n",
    "                    )\n",
    "                \n",
    "                with gr.Row():\n",
    "                    file_format_dropdown = gr.Dropdown(\n",
    "                        label=\"File format\",\n",
    "                        choices=FILE_FORMATS,\n",
    "                        value=\".csv\",\n",
    "                        interactive=True\n",
    "                    )\n",
    "                    \n",
    "                    save_as_input = gr.Textbox(\n",
    "                        label=\"Save as\",\n",
    "                        value=\"basketball_dataset\",\n",
    "                        placeholder=\"Enter filename (extension will be added automatically)\"\n",
    "                    )\n",
    "                \n",
    "                generate_btn = gr.Button(\"🚀 Generate\", variant=\"primary\", size=\"lg\")\n",
    "            \n",
    "            with gr.Column(scale=1):\n",
    "                output_status = gr.Textbox(\n",
    "                    label=\"Status\",\n",
    "                    lines=4,\n",
    "                    interactive=False\n",
    "                )\n",
    "                \n",
    "                output_preview = gr.Dataframe(\n",
    "                    label=\"Preview (First 10 rows)\",\n",
    "                    interactive=False,\n",
    "                    wrap=True\n",
    "                )\n",
    "        \n",
    "        # Connect the generate button\n",
    "        generate_btn.click(\n",
    "            fn=generate_dataset,\n",
    "            inputs=[\n",
    "                schema_input,\n",
    "                system_prompt_input, \n",
    "                model_dropdown,\n",
    "                nr_records_input,\n",
    "                file_format_dropdown,\n",
    "                save_as_input\n",
    "            ],\n",
    "            outputs=[output_status, output_preview]\n",
    "        )\n",
    "        \n",
    "        gr.Markdown(\"\"\"\n",
    "        ### 📝 Instructions:\n",
    "        1. **Schema**: Define the structure of your dataset (pre-filled with basketball player schema)\n",
    "        2. **Prompt**: System prompt to guide the AI model\n",
    "        3. **Model**: Choose between GPT, Claude, or Ollama models\n",
    "        4. **Nr. records**: Number of records to generate (minimum 11)\n",
    "        5. **File format**: Choose output format (.csv, .tsv, .jsonl, .parquet, .arrow)\n",
    "        6. **Save as**: Filename (extension added automatically)\n",
    "        7. Click **Generate** to create your dataset\n",
    "        \n",
    "        ### 🔧 Requirements:\n",
    "        - Set up your API keys in `.env` file (`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`)\n",
    "        - For Ollama models, ensure Ollama is installed and running locally\n",
    "        \"\"\")\n",
    "    \n",
    "    return interface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50fd2b91-2578-4224-b9dd-e28caf6a0a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "interface = create_interface()\n",
    "interface.launch(inbrowser=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
